/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution

import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream}
import java.nio.ByteBuffer

import scala.reflect.ClassTag

import com.google.common.io.ByteStreams

import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.unsafe.PlatformDependent

/**
 * Serializer for serializing [[UnsafeRow]]s during shuffle. Since UnsafeRows are already stored as
 * bytes, this serializer simply copies those bytes to the underlying output stream. When
 * deserializing a stream of rows, instances of this serializer mutate and return a single UnsafeRow
 * instance that is backed by an on-heap byte array.
 *
 * Note that this serializer implements only the [[SerializerInstance]] and stream methods that are
 * used during shuffle; the remaining methods throw UnsupportedOperationException.
 *
 * This serializer does not support UnsafeRows that use
 * [[org.apache.spark.sql.catalyst.util.ObjectPool]].
 *
 * @param numFields the number of fields in the row being serialized.
 */
private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with Serializable {

  /** Builds a fresh [[UnsafeRowSerializerInstance]] for rows that have `numFields` fields. */
  override def newInstance(): SerializerInstance = {
    val instance = new UnsafeRowSerializerInstance(numFields)
    instance
  }

  /**
   * Reports that this serializer supports relocation of serialized objects (see the contract
   * documented on [[Serializer]]). Always `true` here.
   */
  override private[spark] def supportsRelocationOfSerializedObjects: Boolean = {
    true
  }
}

/**
 * A [[SerializerInstance]] that writes each [[UnsafeRow]] as a length-prefixed run of raw bytes and
 * reads such runs back into a single, reused on-heap [[UnsafeRow]].
 *
 * Wire format: for every row, a 4-byte big-endian size (as written by `DataOutputStream.writeInt`)
 * followed by exactly that many bytes of row data. No header or trailer is written for the stream
 * as a whole; a reader stops when the input is cleanly exhausted. Every record is therefore
 * self-contained, so the serialized bytes of whole streams, or of individual records, can be
 * concatenated and still be read back as one stream. `UnsafeRowSerializer` relies on this when it
 * reports `supportsRelocationOfSerializedObjects = true`.
 *
 * @param numFields the number of fields in each row; passed to `UnsafeRow.pointTo` when the reused
 *                  row is pointed at freshly read bytes.
 */
private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance {

  /**
   * Internal "no more rows" sentinel. It can never be a real row size. A literal -1 size in the
   * input is also treated as end-of-stream, so streams that carry an explicit -1 trailer still
   * read correctly.
   */
  private[this] val EOF: Int = -1

  override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
    // Scratch buffer: row bytes are copied here (via base object + offset) before being written.
    private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
    private[this] val dOut: DataOutputStream = new DataOutputStream(out)

    /** Writes `value`, which must be an [[UnsafeRow]] without an ObjectPool, as one record. */
    override def writeValue[T: ClassTag](value: T): SerializationStream = {
      val row = value.asInstanceOf[UnsafeRow]
      assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool")
      val sizeInBytes = row.getSizeInBytes
      dOut.writeInt(sizeInBytes)
      var dataRemaining: Int = sizeInBytes
      val baseObject = row.getBaseObject
      var rowReadPosition: Long = row.getBaseOffset
      while (dataRemaining > 0) {
        val toTransfer: Int = Math.min(writeBuffer.length, dataRemaining)
        PlatformDependent.copyMemory(
          baseObject,
          rowReadPosition,
          writeBuffer,
          PlatformDependent.BYTE_ARRAY_OFFSET,
          toTransfer)
        // Write through dOut, the same wrapper that wrote the size prefix.
        dOut.write(writeBuffer, 0, toTransfer)
        rowReadPosition += toTransfer
        dataRemaining -= toTransfer
      }
      this
    }

    /**
     * Keys must be Ints (presumably partition ids — confirm against callers). They are not
     * written; the reader always reports key 0.
     */
    override def writeKey[T: ClassTag](key: T): SerializationStream = {
      assert(key.isInstanceOf[Int])
      this
    }

    override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream =
      throw new UnsupportedOperationException

    override def writeObject[T: ClassTag](t: T): SerializationStream =
      throw new UnsupportedOperationException

    override def flush(): Unit = dOut.flush()

    override def close(): Unit = {
      // Deliberately no end-of-stream marker: the reader detects end of input itself. A marker
      // would truncate reads of concatenated streams, which relocation allows. Skipping it also
      // keeps a repeated close() from writing to an already-closed stream.
      writeBuffer = null
      dOut.close()
    }
  }

  override def deserializeStream(in: InputStream): DeserializationStream = {
    new DeserializationStream {
      private[this] val dIn: DataInputStream = new DataInputStream(in)
      // Reused across records; replaced by a larger array whenever a row does not fit.
      private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
      private[this] var row: UnsafeRow = new UnsafeRow()
      // The same tuple and row instances are returned for every record. Callers that keep a row
      // past the next call to next() must copy it.
      private[this] var rowTuple: (Int, UnsafeRow) = (0, row)

      override def asKeyValueIterator: Iterator[(Int, UnsafeRow)] = {
        new Iterator[(Int, UnsafeRow)] {

          /**
           * Reads the next record's size. Returns EOF if the input ends cleanly before any byte
           * of the size. A stream that ends partway through the 4-byte size is corrupt, so the
           * EOFException from readUnsignedByte is propagated.
           */
          private[this] def readSize(): Int = {
            val firstByte = dIn.read()
            if (firstByte == -1) {
              EOF
            } else {
              // Finish decoding the big-endian int; operands are evaluated left to right.
              (firstByte << 24) | (dIn.readUnsignedByte() << 16) |
                (dIn.readUnsignedByte() << 8) | dIn.readUnsignedByte()
            }
          }

          private[this] var rowSize: Int = readSize()

          override def hasNext: Boolean = rowSize != EOF

          override def next(): (Int, UnsafeRow) = {
            if (!hasNext) {
              throw new NoSuchElementException("next on exhausted UnsafeRow stream")
            }
            if (rowBuffer.length < rowSize) {
              rowBuffer = new Array[Byte](rowSize)
            }
            ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
            row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null)
            rowSize = readSize() // read the next row's size
            if (rowSize == EOF) {
              // This is the last row in the stream. Drop our references so the byte array can be
              // garbage collected once the caller is done with the returned row.
              val lastRowTuple = rowTuple
              row = null
              rowBuffer = null
              rowTuple = null
              lastRowTuple
            } else {
              rowTuple
            }
          }
        }
      }

      override def asIterator: Iterator[Any] = throw new UnsupportedOperationException
      override def readKey[T: ClassTag](): T = throw new UnsupportedOperationException
      override def readValue[T: ClassTag](): T = throw new UnsupportedOperationException
      override def readObject[T: ClassTag](): T = throw new UnsupportedOperationException
      override def close(): Unit = dIn.close()
    }
  }

  override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
  override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
    throw new UnsupportedOperationException
  override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
    throw new UnsupportedOperationException
}
